{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "emotic.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyNTRS+z3BPWqTSv2PkmiNrg",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "a8cada3fef3846b2bffe52edacbc190d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_ae34432e333e4671b3f7f934de91027b",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_674e1fd300d042cbaf9f0e53e7ac4ecd",
              "IPY_MODEL_88b121247db64a3490f8c1b16f68c696",
              "IPY_MODEL_57ffb85897da4061b318afddca2eed81"
            ]
          }
        },
        "ae34432e333e4671b3f7f934de91027b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "674e1fd300d042cbaf9f0e53e7ac4ecd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_9cb235dbbe234dbe805b4aa00f7d54ae",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": "100%",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_5e048b1fa84146c8bd2b63a19239cb9e"
          }
        },
        "88b121247db64a3490f8c1b16f68c696": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_e6a413c0b59f466b9213c1904b1f57f8",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 46830571,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 46830571,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_e0d0abfa1e9441f58722b064823c8119"
          }
        },
        "57ffb85897da4061b318afddca2eed81": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_a1bbd4436c154378839f58483fa8c261",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 44.7M/44.7M [00:00&lt;00:00, 134MB/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_4f6e592ca3f34209af0ae78a635fc346"
          }
        },
        "9cb235dbbe234dbe805b4aa00f7d54ae": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "5e048b1fa84146c8bd2b63a19239cb9e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "e6a413c0b59f466b9213c1904b1f57f8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "e0d0abfa1e9441f58722b064823c8119": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "a1bbd4436c154378839f58483fa8c261": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "4f6e592ca3f34209af0ae78a635fc346": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/Tandon-A/emotic/blob/master/Colab_train_emotic.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_5Xan2tnR89K"
      },
      "source": [
        "<h1><center> Emotions in context (Emotic) </center></h1>\n",
        "<center> Using context information to recognize emotions in images</center>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rbCWI0rkt8yp"
      },
      "source": [
        "<h1>Project context</h1>\n",
        "\n",
        "Humans use their facial features or expressions to convey how they feel, such as a person may smile when happy and scowl when angry. Historically, computer vision research has focussed on analyzing and learning these facial features to recognize emotions. \n",
        "However, these facial features are not universal and vary extensively across cultures and situations. \n",
        "\n",
        "<figure>\n",
        "<img src=\"https://raw.githubusercontent.com/Tandon-A/emotic/master/assets/face.jpg\"> <img src=\"https://raw.githubusercontent.com/Tandon-A/emotic/master/assets/full_scene.jpg\" width=\"400\">\n",
        "  <figcaption>Fig 1: a) (Facial feature) The person looks angry or in pain b) (Whole scene) The person looks elated.</figcaption>\n",
        "</figure>\n",
        "\n",
        "\n",
        "A scene context, as shown in the figure above, can provide additional information about the situations. This project explores the use of context in recognizing emotions in images. \n",
        "\n",
        "This project uses the <a href=\"http://sunai.uoc.edu/emotic/download.html\">EMOTIC dataset</a> and follows the methodology as introduced in the paper <a href=\"https://arxiv.org/pdf/2003.13401.pdf\">'Context based emotion recognition using EMOTIC dataset'</a>."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1YFaW8HlNWnE",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "7cc564d6-4503-4b5a-bac8-a4fe0bdbcb65"
      },
      "source": [
        "# Linking Google drive to use preprocessed data \n",
        "from google.colab import drive\n",
        "\n",
        "# This will prompt for authorization.\n",
        "drive.mount('/content/drive')\n",
        "#/content/drive/My Drive//"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Mounted at /content/drive\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FhzX7KUihZqu"
      },
      "source": [
        "# I. Prepare places pretrained model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uYgeeri3wdCM",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "59be98ac-4cc9-403c-e116-bac36e368e8b"
      },
      "source": [
        "# Get Resnet18 model trained on places dataset. \n",
        "!mkdir ./places\n",
        "!wget http://places2.csail.mit.edu/models_places365/resnet18_places365.pth.tar -O ./places/resnet18_places365.pth.tar"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "--2021-08-17 17:32:18--  http://places2.csail.mit.edu/models_places365/resnet18_places365.pth.tar\n",
            "Resolving places2.csail.mit.edu (places2.csail.mit.edu)... 128.30.195.26\n",
            "Connecting to places2.csail.mit.edu (places2.csail.mit.edu)|128.30.195.26|:80... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 45506139 (43M) [application/x-tar]\n",
            "Saving to: ‘./places/resnet18_places365.pth.tar’\n",
            "\n",
            "./places/resnet18_p 100%[===================>]  43.40M  24.3MB/s    in 1.8s    \n",
            "\n",
            "2021-08-17 17:32:20 (24.3 MB/s) - ‘./places/resnet18_places365.pth.tar’ saved [45506139/45506139]\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RhWL6Qi_w4qp",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "4803750e-9487-4589-ef86-d8244ed698ca"
      },
      "source": [
        "# Saving the model weights to use ahead in the notebook\n",
        "import torch\n",
        "from torch.autograd import Variable as V\n",
        "import torchvision.models as models\n",
        "from PIL import Image\n",
        "from torchvision import transforms as trn\n",
        "from torch.nn import functional as F\n",
        "import os\n",
        "\n",
        "# the architecture to use\n",
        "arch = 'resnet18'\n",
        "model_weight = os.path.join('./places', 'resnet18_places365.pth.tar')\n",
        "\n",
        "# create the network architecture\n",
        "model = models.__dict__[arch](num_classes=365)\n",
        "\n",
        "#model_weight = '%s_places365.pth.tar' % arch\n",
        "\n",
        "checkpoint = torch.load(model_weight, map_location=lambda storage, loc: storage) # model trained in GPU could be deployed in CPU machine like this!\n",
        "state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()} # the data parallel layer will add 'module' before each layer name\n",
        "model.load_state_dict(state_dict)\n",
        "model.eval()\n",
        "\n",
        "model.cpu()\n",
        "torch.save(model.state_dict(), './places/resnet18_state_dict.pth')\n",
        "print ('completed cell')"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "completed cell\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ykNjfrUuhpbq"
      },
      "source": [
        "# II. General imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vi-O8QgwvOQY",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "6f5857a3-f3af-4dbb-dd7f-8539fab5b9e7"
      },
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import os\n",
        "from PIL import Image\n",
        "import scipy.io\n",
        "from sklearn.metrics import average_precision_score, precision_recall_curve\n",
        "\n",
        "import torch \n",
        "import torch.nn as nn \n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim \n",
        "from torch.utils.data import Dataset, DataLoader \n",
        "from torchsummary import summary\n",
        "from torchvision import transforms\n",
        "import torchvision.models as models\n",
        "from torch.optim.lr_scheduler import StepLR\n",
        "\n",
        "print ('completed cell')"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "completed cell\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AD0pBBBYh2vW"
      },
      "source": [
        "# III. Emotic classes"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZfPKerg4TWkR"
      },
      "source": [
        "## Emotic Model "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZWt88EcJVu0c",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "cd2365da-0d45-4616-800c-e1fd6f565c29"
      },
      "source": [
        "class Emotic(nn.Module):\n",
        "  ''' Emotic Model'''\n",
        "  def __init__(self, num_context_features, num_body_features):\n",
        "    super(Emotic,self).__init__()\n",
        "    self.num_context_features = num_context_features\n",
        "    self.num_body_features = num_body_features\n",
        "    self.fc1 = nn.Linear((self.num_context_features + num_body_features), 256)\n",
        "    self.bn1 = nn.BatchNorm1d(256)\n",
        "    self.d1 = nn.Dropout(p=0.5)\n",
        "    self.fc_cat = nn.Linear(256, 26)\n",
        "    self.fc_cont = nn.Linear(256, 3)\n",
        "    self.relu = nn.ReLU()\n",
        "\n",
        "    \n",
        "  def forward(self, x_context, x_body):\n",
        "    context_features = x_context.view(-1, self.num_context_features)\n",
        "    body_features = x_body.view(-1, self.num_body_features)\n",
        "    fuse_features = torch.cat((context_features, body_features), 1)\n",
        "    fuse_out = self.fc1(fuse_features)\n",
        "    fuse_out = self.bn1(fuse_out)\n",
        "    fuse_out = self.relu(fuse_out)\n",
        "    fuse_out = self.d1(fuse_out)    \n",
        "    cat_out = self.fc_cat(fuse_out)\n",
        "    cont_out = self.fc_cont(fuse_out)\n",
        "    return cat_out, cont_out\n",
        "\n",
        "print ('completed cell')"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "completed cell\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zdzZGj6AxLaC"
      },
      "source": [
        "## Emotic Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eKG5dNMXxlnm",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "890ab105-8973-4be7-be1f-670a816d6b79"
      },
      "source": [
        "class Emotic_PreDataset(Dataset):\n",
        "  ''' Custom Emotic dataset class. Use preprocessed data stored in npy files. '''\n",
        "  def __init__(self, x_context, x_body, y_cat, y_cont, transform, context_norm, body_norm):\n",
        "    super(Emotic_PreDataset,self).__init__()\n",
        "    self.x_context = x_context\n",
        "    self.x_body = x_body\n",
        "    self.y_cat = y_cat \n",
        "    self.y_cont = y_cont\n",
        "    self.transform = transform \n",
        "    self.context_norm = transforms.Normalize(context_norm[0], context_norm[1])  # Normalizing the context image with context mean and context std\n",
        "    self.body_norm = transforms.Normalize(body_norm[0], body_norm[1])           # Normalizing the body image with body mean and body std\n",
        "\n",
        "  def __len__(self):\n",
        "    return len(self.y_cat)\n",
        "  \n",
        "  def __getitem__(self, index):\n",
        "    image_context = self.x_context[index]\n",
        "    image_body = self.x_body[index]\n",
        "    cat_label = self.y_cat[index]\n",
        "    cont_label = self.y_cont[index]\n",
        "    return self.context_norm(self.transform(image_context)), self.body_norm(self.transform(image_body)), torch.tensor(cat_label, dtype=torch.float32), torch.tensor(cont_label, dtype=torch.float32)/10.0\n",
        "\n",
        "print ('completed cell')"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "completed cell\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JFuEQruAxQrK"
      },
      "source": [
        "## Emotic Losses"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ObffJVXkqsJg",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "9665ef7f-44a7-4ddf-db6f-a4a0e6430061"
      },
      "source": [
        "class DiscreteLoss(nn.Module):\n",
        "  ''' Class to measure loss between categorical emotion predictions and labels.'''\n",
        "  def __init__(self, weight_type='mean', device=torch.device('cpu')):\n",
        "    super(DiscreteLoss, self).__init__()\n",
        "    self.weight_type = weight_type\n",
        "    self.device = device\n",
        "    if self.weight_type == 'mean':\n",
        "      self.weights = torch.ones((1,26))/26.0\n",
        "      self.weights = self.weights.to(self.device)\n",
        "    elif self.weight_type == 'static':\n",
        "      self.weights = torch.FloatTensor([0.1435, 0.1870, 0.1692, 0.1165, 0.1949, 0.1204, 0.1728, 0.1372, 0.1620,\n",
        "         0.1540, 0.1987, 0.1057, 0.1482, 0.1192, 0.1590, 0.1929, 0.1158, 0.1907,\n",
        "         0.1345, 0.1307, 0.1665, 0.1698, 0.1797, 0.1657, 0.1520, 0.1537]).unsqueeze(0)\n",
        "      self.weights = self.weights.to(self.device)\n",
        "    \n",
        "  def forward(self, pred, target):\n",
        "    if self.weight_type == 'dynamic':\n",
        "      self.weights = self.prepare_dynamic_weights(target)\n",
        "      self.weights = self.weights.to(self.device)\n",
        "    loss = (((pred - target)**2) * self.weights)\n",
        "    return loss.sum() \n",
        "\n",
        "  def prepare_dynamic_weights(self, target):\n",
        "    target_stats = torch.sum(target, dim=0).float().unsqueeze(dim=0).cpu()\n",
        "    weights = torch.zeros((1,26))\n",
        "    weights[target_stats != 0 ] = 1.0/torch.log(target_stats[target_stats != 0].data + 1.2)\n",
        "    weights[target_stats == 0] = 0.0001\n",
        "    return weights\n",
        "\n",
        "\n",
        "class ContinuousLoss_L2(nn.Module):\n",
        "  ''' Class to measure loss between continuous emotion dimension predictions and labels. Using l2 loss as base. '''\n",
        "  def __init__(self, margin=1):\n",
        "    super(ContinuousLoss_L2, self).__init__()\n",
        "    self.margin = margin\n",
        "  \n",
        "  def forward(self, pred, target):\n",
        "    labs = torch.abs(pred - target)\n",
        "    loss = labs ** 2 \n",
        "    loss[ (labs < self.margin) ] = 0.0\n",
        "    return loss.sum()\n",
        "\n",
        "\n",
        "class ContinuousLoss_SL1(nn.Module):\n",
        "  ''' Class to measure loss between continuous emotion dimension predictions and labels. Using smooth l1 loss as base. '''\n",
        "  def __init__(self, margin=1):\n",
        "    super(ContinuousLoss_SL1, self).__init__()\n",
        "    self.margin = margin\n",
        "  \n",
        "  def forward(self, pred, target):\n",
        "    labs = torch.abs(pred - target)\n",
        "    loss = 0.5 * (labs ** 2)\n",
        "    loss[ (labs > self.margin) ] = labs[ (labs > self.margin) ] - 0.5\n",
        "    return loss.sum()\n",
        "\n",
        "print ('completed cell')"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "completed cell\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-AMUYcy5h9cM"
      },
      "source": [
        "# IV. Load preprocessed data"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VSadne_Bc5va",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "cea63663-6140-4666-8a80-3e69434b92d6"
      },
      "source": [
        "# Change data_src variable as per your drive\n",
        "data_src = '/content/drive/My Drive/Colab/Emotic/data'\n",
        "\n",
        "\n",
        "# Load training preprocessed data\n",
        "train_context = np.load(os.path.join(data_src,'pre','train_context_arr.npy'))\n",
        "train_body = np.load(os.path.join(data_src,'pre','train_body_arr.npy'))\n",
        "train_cat = np.load(os.path.join(data_src,'pre','train_cat_arr.npy'))\n",
        "train_cont = np.load(os.path.join(data_src,'pre','train_cont_arr.npy'))\n",
        "\n",
        "# Load validation preprocessed data \n",
        "val_context = np.load(os.path.join(data_src,'pre','val_context_arr.npy'))\n",
        "val_body = np.load(os.path.join(data_src,'pre','val_body_arr.npy'))\n",
        "val_cat = np.load(os.path.join(data_src,'pre','val_cat_arr.npy'))\n",
        "val_cont = np.load(os.path.join(data_src,'pre','val_cont_arr.npy'))\n",
        "\n",
        "# Load testing preprocessed data\n",
        "test_context = np.load(os.path.join(data_src,'pre','test_context_arr.npy'))\n",
        "test_body = np.load(os.path.join(data_src,'pre','test_body_arr.npy'))\n",
        "test_cat = np.load(os.path.join(data_src,'pre','test_cat_arr.npy'))\n",
        "test_cont = np.load(os.path.join(data_src,'pre','test_cont_arr.npy'))\n",
        "\n",
        "# Categorical emotion classes\n",
        "cat = ['Affection', 'Anger', 'Annoyance', 'Anticipation', 'Aversion', 'Confidence', 'Disapproval', 'Disconnection',\n",
        "       'Disquietment', 'Doubt/Confusion', 'Embarrassment', 'Engagement', 'Esteem', 'Excitement', 'Fatigue', 'Fear',\n",
        "       'Happiness', 'Pain', 'Peace', 'Pleasure', 'Sadness', 'Sensitivity', 'Suffering', 'Surprise', 'Sympathy', 'Yearning']\n",
        "\n",
        "cat2ind = {}\n",
        "ind2cat = {}\n",
        "for idx, emotion in enumerate(cat):\n",
        "  cat2ind[emotion] = idx\n",
        "  ind2cat[idx] = emotion\n",
        "\n",
        "print ('train ', 'context ', train_context.shape, 'body', train_body.shape, 'cat ', train_cat.shape, 'cont', train_cont.shape)\n",
        "print ('val ', 'context ', val_context.shape, 'body', val_body.shape, 'cat ', val_cat.shape, 'cont', val_cont.shape)\n",
        "print ('test ', 'context ', test_context.shape, 'body', test_body.shape, 'cat ', test_cat.shape, 'cont', test_cont.shape)\n",
        "print ('completed cell')"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "train  context  (23266, 224, 224, 3) body (23266, 128, 128, 3) cat  (23266, 26) cont (23266, 3)\n",
            "val  context  (3315, 224, 224, 3) body (3315, 128, 128, 3) cat  (3315, 26) cont (3315, 3)\n",
            "test  context  (7203, 224, 224, 3) body (7203, 128, 128, 3) cat  (7203, 26) cont (7203, 3)\n",
            "completed cell\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JySFyUFZNgPy",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "84ba41f4-7fee-466e-b1b1-ff2fce976395"
      },
      "source": [
        "batch_size = 26\n",
        "\n",
        "context_mean = [0.4690646, 0.4407227, 0.40508908]\n",
        "context_std = [0.2514227, 0.24312855, 0.24266963]\n",
        "body_mean = [0.43832874, 0.3964344, 0.3706214]\n",
        "body_std = [0.24784276, 0.23621225, 0.2323653]\n",
        "context_norm = [context_mean, context_std]\n",
        "body_norm = [body_mean, body_std]\n",
        "\n",
        "\n",
        "train_transform = transforms.Compose([transforms.ToPILImage(), \n",
        "                                      transforms.RandomHorizontalFlip(), \n",
        "                                      transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), \n",
        "                                      transforms.ToTensor()])\n",
        "test_transform = transforms.Compose([transforms.ToPILImage(), \n",
        "                                     transforms.ToTensor()])\n",
        "\n",
        "train_dataset = Emotic_PreDataset(train_context, train_body, train_cat, train_cont, \\\n",
        "                                  train_transform, context_norm, body_norm)\n",
        "val_dataset = Emotic_PreDataset(val_context, val_body, val_cat, val_cont, \\\n",
        "                                test_transform, context_norm, body_norm)\n",
        "test_dataset = Emotic_PreDataset(test_context, test_body, test_cat, test_cont, \\\n",
        "                                 test_transform, context_norm, body_norm)\n",
        "\n",
        "train_loader = DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True)\n",
        "val_loader = DataLoader(val_dataset, batch_size, shuffle=False)\n",
        "test_loader = DataLoader(test_dataset, batch_size, shuffle=False) \n",
        "\n",
        "print ('train loader ', len(train_loader), 'val loader ', len(val_loader), 'test', len(test_loader))\n",
        "print ('completed cell')"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "train loader  894 val loader  128 test 278\n",
            "completed cell\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wvPoFnAliZBC"
      },
      "source": [
        "# V. Prepare emotic model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cMSaPqJyVyEW",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 83,
          "referenced_widgets": [
            "a8cada3fef3846b2bffe52edacbc190d",
            "ae34432e333e4671b3f7f934de91027b",
            "674e1fd300d042cbaf9f0e53e7ac4ecd",
            "88b121247db64a3490f8c1b16f68c696",
            "57ffb85897da4061b318afddca2eed81",
            "9cb235dbbe234dbe805b4aa00f7d54ae",
            "5e048b1fa84146c8bd2b63a19239cb9e",
            "e6a413c0b59f466b9213c1904b1f57f8",
            "e0d0abfa1e9441f58722b064823c8119",
            "a1bbd4436c154378839f58483fa8c261",
            "4f6e592ca3f34209af0ae78a635fc346"
          ]
        },
        "outputId": "b1b68154-bcfc-438a-c711-31b84177d56c"
      },
      "source": [
        "model_path_places = './places'\n",
        "\n",
        "model_context = models.__dict__[arch](num_classes=365)\n",
        "context_state_dict = torch.load(os.path.join(model_path_places, 'resnet18_state_dict.pth'))\n",
        "model_context.load_state_dict(context_state_dict)\n",
        "\n",
        "model_body = models.resnet18(pretrained=True)\n",
        "\n",
        "emotic_model = Emotic(list(model_context.children())[-1].in_features, list(model_body.children())[-1].in_features)\n",
        "model_context = nn.Sequential(*(list(model_context.children())[:-1]))\n",
        "model_body = nn.Sequential(*(list(model_body.children())[:-1]))\n",
        "\n",
        "\n",
        "# print (summary(model_context, (3,224,224), device=\"cpu\"))\n",
        "# print (summary(model_body, (3,128,128), device=\"cpu\"))\n",
        "\n",
        "print ('completed cell')"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "completed cell\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rE5qh_ljPOqs"
      },
      "source": [
        "## Prepare optimizer"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "I6-3FTclWAGh",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e28c5abf-7123-4d7f-d7c6-816884411f29"
      },
      "source": [
        "for param in emotic_model.parameters():\n",
        "  param.requires_grad = True\n",
        "for param in model_context.parameters():\n",
        "  param.requires_grad = False\n",
        "for param in model_body.parameters():\n",
        "  param.requires_grad = False\n",
        "\n",
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "opt = optim.Adam((list(emotic_model.parameters()) + list(model_context.parameters()) + \\\n",
        "                  list(model_body.parameters())), lr=0.001, weight_decay=5e-4)\n",
        "scheduler = StepLR(opt, step_size=7, gamma=0.1)\n",
        "\n",
        "disc_loss = DiscreteLoss('dynamic', device)\n",
        "cont_loss_SL1 = ContinuousLoss_SL1()\n",
        "\n",
        "print ('completed cell')"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "completed cell\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hvUH2QxGjCEc"
      },
      "source": [
        "# VI. Train model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wqtB3MrzA3Uj",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "28fced80-07fc-4b30-b00e-beb19c096fd5"
      },
      "source": [
        "def train_emotic(epochs, model_path, opt, scheduler, models, disc_loss, cont_loss, cat_loss_param=0.5, cont_loss_param=0.5):\n",
        "  if not os.path.exists(model_path):\n",
        "    os.makedirs(model_path)\n",
        "  \n",
        "  min_loss = np.inf\n",
        "\n",
        "  train_loss = list()\n",
        "  val_loss = list()\n",
        "\n",
        "  model_context, model_body, emotic_model = models\n",
        "\n",
        "  for e in range(epochs):\n",
        "    running_loss = 0.0\n",
        "\n",
        "    emotic_model.to(device)\n",
        "    model_context.to(device)\n",
        "    model_body.to(device)\n",
        "    \n",
        "    emotic_model.train()\n",
        "    model_context.train()\n",
        "    model_body.train()\n",
        "    \n",
        "    for images_context, images_body, labels_cat, labels_cont in iter(train_loader):\n",
        "      images_context = images_context.to(device)\n",
        "      images_body = images_body.to(device)\n",
        "      labels_cat = labels_cat.to(device)\n",
        "      labels_cont = labels_cont.to(device)\n",
        "\n",
        "      opt.zero_grad()\n",
        "\n",
        "      pred_context = model_context(images_context)\n",
        "      pred_body = model_body(images_body)\n",
        "\n",
        "      pred_cat, pred_cont = emotic_model(pred_context, pred_body)\n",
        "      cat_loss_batch = disc_loss(pred_cat, labels_cat)\n",
        "      cont_loss_batch = cont_loss(pred_cont * 10, labels_cont * 10)\n",
        "      loss = (cat_loss_param * cat_loss_batch) + (cont_loss_param * cont_loss_batch)\n",
        "      running_loss += loss.item()\n",
        "      loss.backward()\n",
        "      opt.step()\n",
        "\n",
        "    if e % 1 == 0: \n",
        "      print ('epoch = %d training loss = %.4f' %(e, running_loss))\n",
        "    train_loss.append(running_loss)\n",
        "\n",
        "    \n",
        "    running_loss = 0.0 \n",
        "    emotic_model.eval()\n",
        "    model_context.eval()\n",
        "    model_body.eval()\n",
        "    \n",
        "    with torch.no_grad():\n",
        "      for images_context, images_body, labels_cat, labels_cont in iter(val_loader):\n",
        "        images_context = images_context.to(device)\n",
        "        images_body = images_body.to(device)\n",
        "        labels_cat = labels_cat.to(device)\n",
        "        labels_cont = labels_cont.to(device)\n",
        "\n",
        "        pred_context = model_context(images_context)\n",
        "        pred_body = model_body(images_body)\n",
        "        \n",
        "        pred_cat, pred_cont = emotic_model(pred_context, pred_body)\n",
        "        cat_loss_batch = disc_loss(pred_cat, labels_cat)\n",
        "        cont_loss_batch = cont_loss(pred_cont * 10, labels_cont * 10)\n",
        "        loss = (cat_loss_param * cat_loss_batch) + (cont_loss_param * cont_loss_batch)\n",
        "        running_loss += loss.item()\n",
        "\n",
        "      if e % 1 == 0:\n",
        "        print ('epoch = %d validation loss = %.4f' %(e, running_loss))\n",
        "    val_loss.append(running_loss)\n",
        "      \n",
        "    scheduler.step()\n",
        "\n",
        "    if val_loss[-1] < min_loss:\n",
        "        min_loss = val_loss[-1]\n",
        "        # saving models for lowest loss\n",
        "        print ('saving model at epoch e = %d' %(e))\n",
        "        emotic_model.to(\"cpu\")\n",
        "        model_context.to(\"cpu\")\n",
        "        model_body.to(\"cpu\")\n",
        "        torch.save(emotic_model, os.path.join(model_path, 'model_emotic1.pth'))\n",
        "        torch.save(model_context, os.path.join(model_path, 'model_context1.pth'))\n",
        "        torch.save(model_body, os.path.join(model_path, 'model_body1.pth'))\n",
        "\n",
        "  print ('completed training')\n",
        "  \n",
        "  f, (ax1, ax2) = plt.subplots(1, 2, figsize = (6, 6))\n",
        "  f.suptitle('emotic')\n",
        "  ax1.plot(range(0,len(train_loss)),train_loss, color='Blue')\n",
        "  ax2.plot(range(0,len(val_loss)),val_loss, color='Red')\n",
        "  ax1.legend(['train'])\n",
        "  ax2.legend(['val'])\n",
        "\n",
        "print ('completed cell')"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "completed cell\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "i1KsKv_hwoUC",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "9a9ba991-6865-474b-e396-21ce71daec09"
      },
      "source": [
        "train_emotic(15, './models', opt, scheduler, [model_context, model_body, emotic_model], disc_loss, cont_loss_SL1)"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)\n",
            "  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "epoch = 0 training loss = 64863.3032\n",
            "epoch = 0 validation loss = 6286.3406\n",
            "saving model at epoch e = 0\n",
            "epoch = 1 training loss = 47588.2536\n",
            "epoch = 1 validation loss = 5754.6668\n",
            "saving model at epoch e = 1\n",
            "epoch = 2 training loss = 45589.5908\n",
            "epoch = 2 validation loss = 5753.7185\n",
            "saving model at epoch e = 2\n",
            "epoch = 3 training loss = 44254.9966\n",
            "epoch = 3 validation loss = 5677.1488\n",
            "saving model at epoch e = 3\n",
            "epoch = 4 training loss = 43592.7128\n",
            "epoch = 4 validation loss = 5722.5824\n",
            "epoch = 5 training loss = 42938.6911\n",
            "epoch = 5 validation loss = 5569.4062\n",
            "saving model at epoch e = 5\n",
            "epoch = 6 training loss = 42579.4593\n",
            "epoch = 6 validation loss = 5796.7644\n",
            "epoch = 7 training loss = 41934.0057\n",
            "epoch = 7 validation loss = 5575.4900\n",
            "epoch = 8 training loss = 41629.1832\n",
            "epoch = 8 validation loss = 5591.6801\n",
            "epoch = 9 training loss = 41416.1661\n",
            "epoch = 9 validation loss = 5536.8152\n",
            "saving model at epoch e = 9\n",
            "epoch = 10 training loss = 41335.3265\n",
            "epoch = 10 validation loss = 5505.6411\n",
            "saving model at epoch e = 10\n",
            "epoch = 11 training loss = 41142.7416\n",
            "epoch = 11 validation loss = 5551.4237\n",
            "epoch = 12 training loss = 41124.7236\n",
            "epoch = 12 validation loss = 5520.3929\n",
            "epoch = 13 training loss = 41033.9440\n",
            "epoch = 13 validation loss = 5579.3154\n",
            "epoch = 14 training loss = 40919.0560\n",
            "epoch = 14 validation loss = 5558.6090\n",
            "completed training\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAGQCAYAAABMJgwnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dfZyUdb3/8deH5f5GuVVARMDwBjARVrLjTXkTopZoilCpZCaeE1p2UoM6v/RoVJodb05qoZI3aUiYSYUiWpbnHFGXpATEWEFlUWEVQQNlWfj8/vheI8Myy87sXLOzc837+XjMY2e+1818B66Zz/W9N3dHRETKW5tiZ0BERIpPwUBERBQMREREwUBERFAwEBERFAxERAQFA5EWY2aPmtnkYudDJBPTOAOR+JnZ1cDH3P3cYudFJBsqGYiIiIKBlB8z629mD5lZrZmtNrOvR+lXm9mvzeyXZva+mb1oZgeZ2XQzW29ma8xsbIPzzDOzDWZWbWYXRenjgO8AE83sn2b2tyj9KTP7atrxF5nZS9F7LTezUS37LyGyk4KBlBUzawP8DvgbsB9wInCZmZ0c7fI54D6gB/ACsIDwPdkPuAb4edrpZgM1QH/gbOAHZnaCuz8G/AB40N27uvvhGfIxAbgaOB/YCzgdeCfWDyuSAwUDKTdHAn3c/Rp3r3P3VcAdwKRo+9PuvsDd64FfA32AH7n7NsKP/yAz625m+wNHA9929w/dfQlwJ+HHPRtfBa539+c9qHb312L8nCI5aVvsDIi0sAOA/ma2MS2tAngaeA1Yl5b+AfC2u29Pew3QlVAa2ODu76ft/xpQmWU+9gdeyTHvIgWjkoGUmzXAanfvnvbo5u6n5nieN4CeZtYtLW0gsDZ63lQ3vTXAgTm+p0jBKBhIuXkOeN/Mvm1mncyswsxGmNmRuZzE3dcA/wf80Mw6mtnHgQuBX0a7rCNUKTX2HbsTuNzMRlvwMTM7oJmfSSRvCgZSVqIqn88CI4HVwNuEH+a9m3G6LwCDCKWEh4Gr3P2JaNuvo7/vmNlfM+Tj18AM4AHgfeC3QM9m5EEkFhp0JiIiKhmIiIiCgYiIoGAgIiIoGIiICAoGIiKCgoGIiKBgICIiKBiIiAgKBiIigoKBiIigYCAiIigYiIgICgYiIoKCgYiIoGAgIiIoGIiICAoGIiKCgoGIiKBgICIiKBiIiAgKBiIigoKBiIigYCAiIigYiIgICgYiIoKCgYiIoGAgIiIoGIiICAoGIiKCgoGIiKBgICIiKBiIiAjQttgZaK7evXv7oEGDip0NSajFixe/7e59Wvp9dV1LIS1evPg94Bl3H9dwW8kGg0GDBlFVVVXsbEhCmdlrxXhfXddSSGa2MlMgAFUTiYgICgYiIoKCgYiIUMJtBlJY27Zto6amhg8//LDYWSmojh07MmDAANq1a1fsrEgrkYRrvznXtYKBZFRTU0O3bt0YNGgQZlbs7BSEu/POO+9QU1PD4MGDG26uMLO5wAjAga8ApwLjgR3AeuDL7v6GhX+gm6PtW6L0vwKY2WTgP6Jzft/d7yn055L8lPq138R13ShVE0lGH374Ib169SrJL0O2zIxevXo1dge4P/CYux8CHA68BPzY3T/u7iOB3wPfi/Y9BRgaPaYAt0fn7wlcBXwCGANcZWY9CviRJAalfu03cV03SsFAGlWqX4ZcZPqMmzZtAugG3AXg7nXuvtHd30vbrQuhxAChtHCvB4uA7mbWDzgZWOjuG9z9XWAhkLFbn7QupX7tNyf/CgbSKm3cuJHbbrst5+NOPfVUNm7cmNd7r169GqAe+IWZvWBmd5pZFwAzm2Fma4AvsbNksB+wJu0UNVFaY+m7MLMpZlZlZlW1tbV55V3KT9euXWM5T1bBwMy6m9lcM1thZi+Z2SfN7GozW2tmS6LHqWn7TzezajN72cxOTksfF6VVm9m0tPTBZvZslP6gmbWP5dNJyWosGNTX1+/xuPnz59O9e/e83jt6j87A7e5+BLAZmAbg7t919/2B+4FL8nqjiLvPdPdKd6/s06fFBz2LANmXDG5m9/pTgBvdfWT0mA9gZsOAScBwQpH4NjOrMLMK4FZC/eow4AvRvgDXRef6GPAucGEMn01K2LRp03jllVcYOXIkRx55JMceeyynn346w4aFS+aMM85g9OjRDB8+nJkzZ3503KBBg3j77bd59dVXOfTQQ7nooosYPnw4Y8eO5YMPPsjqvQcMGABQ5+7PRklzgVENdrsfOCt6vpbQxvDRKaK0xtJFGjVt2jRuvfXWj15fffXVfP/73+fEE09k1KhRHHbYYTzyyCOxv2+TvYnMbG/gOODLEOpPgbo91EmNB2a7+1ZgtZlVExrPAKrdfVV03tnAeDN7CTgB+GK0zz3A1USNcFJ8l10GS5bEe86RI+Gmmxrf/qMf/YilS5eyZMkSnnrqKU477TSWLl36Ue+IWbNm0bNnTz744AOOPPJIzjrrLHr16rXLOVauXMmvfvUr7rjjDs455xweeughzj333Cbz1rdvXwjX+MHu/jJwIrDczIa6+8pot/HAiuj5POCS6Jr+BLDJ3d80swXAD9IajccC07P595FWoggX/8SJE7nsssuYOnUqAHPmzGHBggV8/etfZ6+99uLtt9/mqKOO4vTTT4+1bSObrqWDgVpC/enhwGLgG9G2S8zsfKAK+FbUSLYfsCjt+PR60ob1p58AegEb3b0+w/67MLMphN4aDBw4cLftmzfD00/DYYfBfhnPIKVqzJgxu3STu+WWW3j44YcBWLNmDStXrtwtGAwePJiRI0cCMHr0aF599dVc3vJ14P6oynIVcAFwp5kdTOha+hrwr9G+8wndSqsJXUsvAHD3DWZ2LfB8tN817r4hl0x8ZOVKqK6GU05p1uFSOo444gjWr1/PG2+8QW1tLT169KBv375885vf5C9/+Qtt2rRh7dq1rFu3LnXjEotsgkFbQhH5Und/1sxuJtSf/hS4ltCj4lrgJ4S+2AXj7jOBmQCVlZXecPtbb4Xvyj33wPnnFzIn5WVPd/AtpUuXLh89f+qpp3jiiSd45pln6Ny5M5/+9KczdqPr0KHDR88rKiqyriaKfODulQ3Szsq0o7s7MLWRbbOAWbm8cUb33AM//CHU10OJ93QpKUW6+CdMmMDcuXN56623mDhxIvfffz+1tbUsXryYdu3aMWjQoNgHxWXTZlAD1DSsP3X3de6+3d13AHewsyoo1/rTdwhd8do2SM9ZqlH9n/9sztHSmnTr1o33338/47ZNmzbRo0cPOnfuzIoVK1i0aFHG/RKla1fYsQNyC2hSoiZOnMjs2bOZO3cuEyZMYNOmTeyzzz60a9eOP/3pT7z2WvyT6jYZDNz9LWBNVDyGnfWn/dJ2OxNYGj2fB0wysw5mNpgwEOc5QlF5aNRzqD2hkXledFf1J+Ds6PjJQLNaRxQMkqNXr14cffTRjBgxgiuuuGKXbePGjaO+vp5DDz2UadOmcdRRRxUply2oW
7fwVxd3WRg+fDjvv/8+++23H/369eNLX/oSVVVVHHbYYdx7770ccsghsb9nttNRXMru9ae3mNlIQjXRq8DFAO6+zMzmAMsJfbWnuvt2ADO7BFgAVACz3H1ZdP5vA7PN7PvAC0SDfXLVqVMoQev7kgwPPPBAxvQOHTrw6KOPZtyWahfo3bs3S5cu/Sj98ssvjz1/LSp1p/P++7DPPsXNi7SIF1988aPnvXv35plnnsm43z9j+sHLKhi4+xKgYf3peXvYfwYwI0P6fEJjW8P0VeysZmq2Nm2gSxcFA0kglQykwBI3ArlLl9CrSCRR0ksGIgWQuGDQtatuniSBVDKQAlMwkEaFtv1kK5nPqJJBiyqZ66IRzcm/goFk1LFjR955552S/1LsSWre944dOxY7K01TyaDFlPq139zrOnGL23TtCu++W+xclL4BAwZQU1ND0mfRTK0I1eqpZNBiknDtN+e6TmQwqKkpdi5KX7t27XJaJUkKTCWDFlOu137iqonUtVQSqX17aNtWJQMpmMQFA7UZSCKZhdKBLm4pEAUDkVLRtatKBlIwiQwGW7fCtm3FzolIzFQykAJKZDAAjUKWBFLJQApIwUCkVKhkIAWUuGCQWgNF3xlJnG7dVDKQgklcMNCaBpJY6h0hBaRgIFIqVDKQAlIwECkVakCWAlIwECkV3bqp37QUTOKCQaoBWb2JJHF0pyMFlLhgoO+LJJYmq5MCUjAQKRWaxloKKHHBIDW5o4KBJI5KBlJAiQsGZuqOLQmlkoEUUOKCASgYSEKpZCAFlMhg0KWLehNJAqlkIAWUyGCgkoEkkkoGUkAKBiKZVZjZXDNbYWYvmdknzezH0eu/m9nDZtY9tbOZTTezajN72cxOTksfF6VVm9m0vHKkkoEUkIKBSGb7A4+5+yHA4cBLwEJghLt/HPgHMB3AzIYBk4DhwDjgNjOrMLMK4FbgFGAY8IVo3+bRlLxSQAoGIg1s2rQJoBtwF4C717n7Rnd/3N3ro90WAQOi5+OB2e6+1d1XA9XAmOhR7e6r3L0OmB3t2zxt2oSAoJKBFEBig4EakKW5Vq9eDVAP/MLMXjCzO82sS4PdvgI8Gj3fD1iTtq0mSmssfRdmNsXMqsysqra2ds+Z0wI3UiCJDAZduuj7Is1XX18P0Bm43d2PADYDH9X3m9l3CcHi/jjez91nunulu1f26dNnzztrGmspkEQGg1Q1kXuxcyKlaMCAAQB17v5slDQXGAVgZl8GPgt8yf2jK2wtoY3ho1NEaY2lN5/qQKVAEhsMtm8Ps/2K5Kpv374AdWZ2cJR0IrDczMYBVwKnu/uWtEPmAZPMrIOZDQaGAs8BzwNDzWywmbUnNDLPyytzKhlIgbQtdgYKIX2yuo4di5sXKVmvA/dHP+KrgAsIP+4dgIVmBrDI3f/V3ZeZ2RxgOaH6aKq7bwcws0uABUAFMMvdl+WVq65dYd26vE4hkknig0Hv3sXNi5SsD9y9skHaxxrb2d1nADMypM8H5seWq27doLo6ttOJpCS2mgjUo0gSSEtfSoEkMhhobI4klrqWSoEkMhhogRtJLHWVkwJRMBApJd26hUCwZUvT+4rkQMFApJRosjopEAUDkVKiaaylQBIdDNSbSBJHJQMpkEQGA/UmksRSyUAKJJHBoKIijDzW90USRyUDKZBEBgPQfF6SUCoZSIEoGIiUEpUMpEASHQzUgCyJo5KBFEhig4EWuJFEUslACiSxwUDVRJJI7duHhy5uiZmCgUip0QI3UgAKBiKlRhe3FICCgUipUclACiDRwUC9iSSRdKcjBZBVMDCz7mY218xWmNlLZvZJM+tpZgvNbGX0t0e0r5nZLWZWbWZ/N7NRaeeZHO2/0swmp6WPNrMXo2NusWiB2Xx06RKCwY4d+Z5JpJVRyUAKINuSwc3AY+5+CHA48BIwDXjS3YcCT0avAU4BhkaPKcDtAGbWE7gK+AQwBrgqFUCifS5KO25cfh9rZw88TfsuiaOlL6UAmgwGZrY3cBxwF4C717n7RmA8cE+02z3AGdHz8cC9HiwCuptZP+BkYKG7b3D3d4GFwLho217uvsjdHbg37VzNpmmsJbG09KUUQDYlg8FALfALM3vBzO40sy7Avu7+ZrTPW8C+0fP9gDVpx9dEaXtKr8mQnhcFA0kslQykALIJBm2BUcDt7n4EsJmdVUIARHf0BV+U1cymmFmVmVXV1tbucV8FA0kslQykALIJBjVAjbs/G72eSwgO66IqHqK/66Pta4H9044fEKXtKX1AhvTduPtMd69098o+ffrsMdNa4EYSq2tXqKsLD5GYNBkM3P0tYI2ZHRwlnQgsB+YBqR5Bk4FHoufzgPOjXkVHAZui6qQFwFgz6xE1HI8FFkTb3jOzo6JeROennavZtMCNJJYmq5MCaJvlfpcC95tZe2AVcAEhkMwxswuB14Bzon3nA6cC1cCWaF/cfYOZXQs8H+13jbtviJ5/Dbgb6AQ8Gj3yomoiSaz0yep69ixuXiQxsgoG7r4EqMyw6cQM+zowtZHzzAJmZUivAkZkk5dsKRhIYqlkIAWQ6BHIoO+LJJCmsZYCUDAQKTUqGUgBJDYYdOoEZupNJM1WkWEKlglmtszMdpjZLtWmZjY9mk7lZTM7OS19XJRWbWbTdn+bZlDJQAog2wbkkmOm1c4kL/sDt7r72VHHic7ARuDzwM/TdzSzYcAkYDjQH3jCzA6KNt8KfIbQRft5M5vn7svzyplKBlIAiQ0GoMkdpXk2bdoE0I20KViAOkIwIMM8iuOB2e6+FVhtZtWE+bcAqt19VXTc7GjfeIKBSgYSo8RWE4GCgTTP6tWrAerZfQqWxuQ6BcsuchlZD6hBTApCwUCkgfr6egjVQo1OwRKnXEbWA9C5c6gHVclAYpT4YKAGZMnVgAEDAOoyTMHSmFynYMmPme50JHaJDwb6vkiu+vbtC1CXYQqWxswDJplZBzMbTFiT4znCaPuhZjY4aoSeFO2bPy1wIzFLdANyly7w+uvFzoWUqNdpMAWLmZ0J/DfQB/iDmS1x95PdfZmZzSEEjHpgqrtvBzCzSwjzclUAs9x9WSy5052OxCzRwUDfF8nDB+7ecAqWh6PHbtx9BjAjQ/p8wnxd8VLJQGKmaiKRUqSLW2KmYCBSilQykJglPhjU1cG2bcXOiUjMtPSlxCzRwSC1wI26l0riaOlLiVmig4EGakpiqWQgMVMwEClF3bqFIu+OHcXOiSSEgoFIKUpd3KoDlZgoGIiUIk1jLTEri2CgmydJHC1wIzFLdDBI9SbSzZMkjkoGErNEBwNVE0liqWQgMVMwEClFKhlIzBQMREqRlr6UmCU6GLRvD+3aKRhIAulOR2KW6GAAWu1MEkolA4lZ4oNBly66eZIEUslAYpb4YKBprCWR2raFjh1VMpDYKBiIlCpd3BIjBQORUqUFbiRGZREM1IAsiaQ7HYlRWQQDfV8kkVQykBglPhioN5Eklha4kRglPhioZCCJpaUvJUZlEwzci50TkZipZCAxKotgsGMHfPhhsXMiEjOVDCRGZREM
QD2KJIFSJQMVeyUGZRMMdAMlidOtG9TXQ11dsXMiCZD4YKDVziSxtMCNxCjxwUAlA0ksLXAjMVIwEMmswszmmtkKM3vJzD5pZj3NbKGZrYz+9gCw4BYzqzazv5vZqNRJzGxytP9KM5scaw5VMpAYKRiIZLY/8Ji7HwIcDrwETAOedPehwJPRa4BTgKHRYwpwO4CZ9QSuAj4BjAGuSgWQWKhkIDEqm2Cg3kSSrU2bNgF0A+4CcPc6d98IjAfuiXa7Bzgjej4euNeDRUB3M+sHnAwsdPcN7v4usBAYF1tGtcCNxCjxwUANyJKr1atXA9QDvzCzF8zsTjPrAuzr7m9Gu70F7Bs93w9Yk3aKmiitsfRdmNkUM6sys6ra2trsM6pir8Qo8cFA3xfJVX19PUBn4HZ3PwLYzM4qIQDc3YFYOvi7+0x3r3T3yj59+mR/oEoGEqPEBwOVDCRXAwYMAKhz92ejpLnAKGBdVP1D9Hd9tH0toY3ho1NEaY2lx0N3OhKjxAeDigro1EnfF8le3759AerM7OAo6URgOTAPSPUImgw8Ej2fB5wf9So6CtgUVSctAMaaWY+o4XhslBYPlQwkRm2LnYGWoAVupBleB+43s/bAKuACws3THDO7EHgNOCfadz5wKlANbIn2xd03mNm1wPPRfte4+4bYctixI7RpozsdiUXZBAN9XyRHH7h7ZYb0ExsmRO0HUzOdxN1nAbNizltgpgVuJDaJryYCLXAjCaY7HYlJWQQDfV8ksVQykJgoGIiUMl3cEhMFA5FSppKBxCSrYGBmr5rZi2a2xMyqorSrzWxtlLbEzE5N2396NGnXy2Z2clr6uCit2sympaUPNrNno/QHox4csVFvIkksLX0pMcmlZHC8u49s0MPixihtpLvPBzCzYcAkYDhhHpbbzKzCzCqAWwmTeg0DvhDtC3BddK6PAe8CF+b3sXalkoEklpa+lJgUoppoPDDb3be6+2pC3+sx0aPa3Ve5ex0wGxhvZgacQBjlCbtOABYL9SaSxFLJQGKSbTBw4HEzW2xmU9LSL4nmb5+VNjVvrpN29QI2unt9g/TYdO0KW7bA9u1xnlWkFVDJQGKSbTA4xt1HEap4pprZcYQ52w8ERgJvAj8pTBZ3au7sjqkpXLZsKVDGRIpFdzoSk6yCgbuvjf6uBx4Gxrj7Onff7u47gDsI1UCQ+6Rd7xDmf2/bID1TPpo1u6Pm85LESs1PpB4Skqcmg4GZdTGzbqnnhMm2lqZmb4ycCSyNns8DJplZBzMbTFj96TnC/CxDo55D7QmNzPOiofx/As6Ojk+fACwWWuBGEkuT1UlMspmbaF/g4dDOS1vgAXd/zMzuM7ORhPaEV4GLAdx9mZnNIczyWA9MdfftAGZ2CWHWxgpglrsvi97j28BsM/s+8ALRClNxUclAEksXt8SkyWDg7qsIa8A2TD9vD8fMAGZkSJ9PmOEx03uMaZgeF61pIImlkoHEpGxGIIOCgSSQLm6JiYKBSClTyUBiomAgUsp0cUtMyioYqDeRJI5KBhKTsggGakCWxFLJQGJSFsGgU6ewQqC+L5I4qWCgkoHkqSyCgZlmLpWEqqgIdzu6uCVPZREMQMFAEkwL3EgMyioYqAFZEkl3OhKDsgoG+r5IIqlkIDEom2CgBW4ksbTAjcSgbIKBSgaSWFrgRmKgYCBS6lQykBgoGIiUOpUMJAZlFQzUm0gSSSUDiUFZBQPdPEkipUoG7sXOiZSwsgkGXbpAXV14iGThMDN70cyWmFkVgJkdbmbPROm/M7O9Ujub2XQzqzazl83s5LT0cVFatZlNK0hOu3WD7dvhww8LcnopD2UTDDRzqTTD8e4+0t0ro9d3AtPc/TDgYeAKADMbRljTezgwDrjNzCrMrAK4FTgFGAZ8Ido3XpqsTmJQdsFA3xfJw0HAX6LnC4GzoufjgdnuvtXdVwPVhGVcxwDV7r7K3euA2dG+8dI01hIDBQORxj1uZovNbEr0ehk7f8wnAPtHz/cD1qQdVxOlNZa+CzObYmZVZlZVW1ubey51cUsMyi4YqJpIsrTC3UcRqnimmtlxwFeAr5nZYqAbEEsLlLvPdPdKd6/s06dP7idQyUBi0LbYGWgpunmSHG0DcPf1ZvYwMMbdbwDGApjZQcBp0b5r2VlKABgQpbGH9Pjo4pYYlE3JQKudSbY2h+JjGwAz60IIAEvNbJ8orQ3wH8DPokPmAZPMrIOZDQaGAs8BzwNDzWywmbUnNDLPiz3DKhlIDFQyEGlg3bp1AIeY2d8I35EH3P0xM/uGmU2NdvsN8AsAd19mZnOA5UA9MNXdtwOY2SXAAqACmOXuy2LPsC5uiYGCgUgDQ4YMAVie1qUUAHe/Gbg50zHuPgOYkSF9PjC/ANncSSUDiUHZVBOpAVkSS3c6EoOyCQZqM5DE6tAB2rZVyUDyUjbBoH378FAwkMQx0+RbkreyCQag1c4kwbT0peSprIKBbp4ksXRxS54UDESSQCUDyVPZBQP1JpJE0gI3kqeyCwYqGUgiaelLyVNZBQM1IEtiqWQgeSqrYKCSgSSWSgaSJwUDkSRQA7LkScFAJAm6dg1rINfXFzsnUqLKLhhs3gzuxc6JSMxSk9XpbkeaqeyCwY4d4QZKJFE0WZ3kqayCgSark8TSNNaSp7IKBrp5ksTSxS15UjAQSQKVDCRPCgYiSaCLW/JUlsFA8xNJ4qhkIHkqy2CgmydJHF3ckqeyCgbqTSSJpZKB5KmsgoFuniSxdKcjeVIwEEmCNm1CQFDJQJqprIJB587hrxqQJZE0+ZbkoayCQUVFCAj6vkgiaeZSyUNZBQPQzZMkmBa4kTyUXTDQameSWFrgRvJQdsFAJQNJLFUTSR6yCgZm9qqZvWhmS8ysKkrraWYLzWxl9LdHlG5mdouZVZvZ381sVNp5Jkf7rzSzyWnpo6PzV0fHWtwfNEXBQBJLF7fkIZeSwfHuPtLdK6PX04An3X0o8GT0GuAUYGj0mALcDiF4AFcBnwDGAFelAki0z0Vpx41r9idqQmqBG5EmHJbhBmikmS1KpZnZmCg95xugglDJQPKQTzXReOCe6Pk9wBlp6fd6sAjobmb9gJOBhe6+wd3fBRYC46Jte7n7Ind34N60c8VON0+Sg4Y3QNcD/+nuI4HvRa+heTdA8dPFLXnINhg48LiZLTazKVHavu7+ZvT8LWDf6Pl+wJq0Y2uitD2l12RILwh9XyQPDuwVPd8beCN6ntMNUMFyl2pA1rqu0gxts9zvGHdfa2b7AAvNbEX6Rnd3Myv4FRgFoikAAwcObNY51JtIcvB4dF3/3N1nApcBC8zsBsKN1L9E++V6A7SLOK5rYOe6rh98sHOEZSlyh2eegU9+EgrXfCgNZFUycPe10d/1wMOEIu+66O6H6O/6aPe1wP5phw+I0vaUPiBDeqZ8zHT3Snev7NOnTzZZ341KBpKlFe4+ilAFNNXMjgP+Dfimu+8PfBO4K443iuO6BpIzWd3zz8PRR8OTTxY7J2WlyWBgZl3MrFvqOTAWWArMA1INYpOBR6Ln84Dzo0a
1o4BNUXXSAmCsmfWI6k3HAguibe+Z2VFRL6Lz084Vu65dw43T9u2FegdJiG2w2w3QZOA30fZfR2mQ+w1QYSRl8q2VK3f9Ky0im5LBvsD/mNnfgOeAP7j7Y8CPgM+Y2UrgpOg1wHxgFVAN3AF8DcDdNwDXAs9Hj2uiNKJ97oyOeQV4NP+Pllnq+7JlS6HeQUrd5tDdrA3sdgP0BvCpaLcTgNSvVU43QAXLeFJKBq+/vutfaRFNthm4+yrg8Azp7wAnZkh3YGoj55oFzMqQXgWMyCK/eUu/eUp9d0TSrVu3DuCQ6AaoLfCAuz9mZv8EbjaztsCHRPX8hBugUwk3M1uACyDcAJlZ6gYIdr0Bil9SSgZromYWBYMWlW0DcmJo2ndpypAhQwCWp3UpBcDd/wcY3XD/5twAFURSSgYKBkVRltNRgIKBJFBSLm5VExWFgoFIUiStZLB2LdTXFzcvZaRsg4GmpJDESWjkNLMAAB5TSURBVMKdzubN8O67MHhw6PL3xhtNHyOxKNtgUMrfF5GMklAySJUKjjkm/FVVUYtRMBBJivbtoV270r64Uz/+Rx+962spuLILBupNJIlW6jOXpkoGCgYtruyCgUoGkmilvvTlmjVhPqKDDoKePRUMWlDZBYOOHaFNGwUDSahSX/ry9dehb99Q5XXAAQoGLajsgoGZFriRBEtCNdH+0XROAwcqGLSgsgsGoJlLJcFK/eJWMCiaspuOAkr/+yLSqG7d4B//gEceCWsadOq089Hwdbt2xc7trtzDj/+pp4bXAwfCpk3hsffexc1bGSjLYKAFbiSxBg6Ehx+GM7JYOfa88+Deewufp2xt2BDml08vGUAIEIcdVrx8lYmyDAYqGUhi3XADXHxx+FH94IMwV3vqefrrhQvhV7+Cm24KvXZag1S3UgWDoijbYFBbW+xciBRA27Zw6KFN7/epT8GCBaEUceGFhc9XNlLBIBUE0oOBFFzZNiCrN5GUtdGjYcgQePDBYudkp9SPfqpk0LdvaNdQMGgRZRsMVE0kZc0MJk6EP/6x9RST16wJP/777htet2kDAwYoGLQQBQORcjVxYpgZ9KGHip2TYM2a8OPfJu1nSd1LW0xZBgP1JhIBPv5xOPjg1lNV9PrrO6uIUjQKucWUZTDo2hW2bYO6umLnRKSIUlVFf/4zvPlmsXOz64CzlIEDtchNCynbYAAqHYgwcWIY7DV3bnHzsX17+NHPFAy2b28dwSrhyjoYqEeRlL1hw2DEiOJXFa1bF+7+U91JU9S9tMWUdTBQyUCEUDr43//d2c+/GBp2K01RMGgxCgYi5W7ixPA3n6qijRtDdU5zNRx9nJJ6/dprzT+3ZKUsg8E++4S/K1YUNx8ircLQoXDEEc2vKlq7FgYNgp/8pPl5aDj6OKVrVy1y00LKMhiMHg0HHgh33VXsnIi0EhMnwrPPwquv5n7sd74TZhZdtKj57//666HPd/fuu2/TWIMWUZbBoE0buOii0KPu5ZeLnRuRVuCcc8LfOXNyO+6558LMp23bwtKlzX//VLdSs923KRi0iLIMBgBf/nK4fu+4o9g5EWkFBg+GMWNyqypyh8suC9NHXHopvPJKmBG1Odas2b2KKEXBoEWUbTDYd98w5fvdd8PWrcXOjUgrMHEi/PWvUF2d3f6zZ8Mzz8APfgBHHQU7djS/IS7T6OOU9EVupGDKNhhAmPb9nXfgN78pdk5EWoEJE8LfbEoHW7bAlVeGhucvfzmMVQBYtiz39926NYwzaCwYHHBA+FvMrq9loKyDwQknhFl8Z84sdk6kFTrMzF40syVmVgVgZg9Gr5eY2atmtiS1s5lNN7NqM3vZzE5OSx8XpVWb2bRifJCs7b8/HH10dsHgxz+Gmhq4+ebQCDd0aJhxtDntBmvX7nz/TDTWoEWUdTBINSQ/9VRYNlakgePdfaS7VwK4+8To9UjgIeA3AGY2DJgEDAfGAbeZWYWZVQC3AqcAw4AvRPu2XhMnwosvwksvNb5PTQ1cd10oSRx7bEhr1y5MeteckkFj3UpTFAxaRFkHA1BDsuTOzAw4B/hVlDQemO3uW919NVANjIke1e6+yt3rgNnRvq3X2WeHHj17Kh1MmxbaB66/ftf04cObVzJobPRxiha5aRFlHwz69oXx49WQLBk9bmaLzWxKg/RjgXXuvjJ6vR+QXqFdE6U1lr4LM5tiZlVmVlVb7IVm+vULS2I++GDoLdTQokVw//3wrW+FgWbpRowI4xRyHdrf2OjjlNQiNxqFXFBlHwwApkyBt98Oy8GKRFa4+yhCFc9UMzsubdsX2FkqyJu7z3T3Snev7NOnT1ynbb6JE0OvoBdf3DV9x47QlbRfP5g+fffjhg8Pf5cvz+391qyBXr2gc+fG91H30oJTMABOOil0s1ZDsqTZBuDu64GHCVU+mFlb4PNAej3KWiD9tnZAlNZYeut21llQUbF7VdEDD4RRyj/84c4JvtI1t0fRnrqVpigYFJyCATsbkv/0JzUkC2wOc5u3ATCzLsBYIFUZfhKh1FCTdsg8YJKZdTCzwcBQ4DngeWComQ02s/aERuZ5LfMp8tCnT+hql15VtHlzaCuorITzzst83JAh0LFj7u0GmRa1aUiL3BScgkHkggtCQ/KddxY7J1Js69atAzjEzP5G+FH/g7s/Fm2eRIMqIndfBswBlgOPAVPdfbu71wOXAAuAl4A50b6t38SJYUTxX/8aXl9/ffgxvummXdcoTldRAYcemnvJYE+jj1O0yE3BKRhE+vaF00+HX/xCDcnlbsiQIQDL3f1wdx/u7jNS29z9y+7+s4bHuPsMdz/Q3Q9290fT0ue7+0HRthkNj2u1zjwz3B09+GConrn+epg0KYxD2JPhw3MLBu+/H6a/zqZkAKoqKiAFgzSphuTf/rbYOREpsp49YezYMHHdlVeGtOuua/q4ESPCOISNG7N7n6Z6EqWkRiErGBSMgkGaz3wm9JZTQ7IIoarotddC6eCKK5quyoHcexRlGwxS2xUMCkbBIE2qIfmPf4SVK5veXyTRxo+H9u2hf/+dpYOmpHoUZduI3NTo4xQtclNwCgYNXHBBaAdTQ7KUvb33Do1oc+Zk7kqaycCBYZGabNsNXn89jHju3z+7cysYFIyCQQP9+u1sSK6rK3ZuRIrsi19sutE4XZs2uU1LsWZN+NK1a9f0vgMHahRyASkYZDBlCtTWqiFZpFly6VGUTbfSFJUMCkrBIIPPfCZ0XlBDskgzjBgR1id4++2m981m9HGKFrkpKAWDDCoqQkPyk09mv+iTiERSPYqaKh24Zzf6OCVVgtAiNwWhYNAINSSLNFO2PYreeQc+/DC3aiJQVVGBKBg0on9/+NznNCJZJGf9+4eeSE2VDLIdY5CiYFBQCgZ7cOmlsH493HhjsXMiUkLMQumgqZJBU4vaNNS3b5giQ8GgIBQM9uCEE8IULddeq2pKkZykehRlWiAnJdeSQUVF2FfBoCAUDJpw443hev73fy92TkRKyIgRsGEDvPVW4/usWRNGOO+zT/bnVffSglEwaMIBB8B3vg
Nz58ITTxQ7NyIlIpseRa+/HpazbGxK7EwUDAom6/8FM6swsxfM7PfR67vNbLWZLYkeI6N0M7NbzKzazP5uZqPSzjHZzFZGj8lp6aPN7MXomFuiBcdbjcsvhwMPhEsu0ahkkaxk06Mol26lKQMHhllRtchN7HIpGXyDsEBHuivcfWT0WBKlnUJY6WkoMAW4HcDMegJXAZ8gLCF4lZn1iI65Hbgo7bhxzfgsBdOxI9xyC7z8cljbQ0SasM8+0Lv3nksGuYw+TtEiNwWTVTAwswHAaUA2ve7HA/d6sAjobmb9gJOBhe6+wd3fBRYC46Jte7n7Ind34F7gjOZ8mEI69dQwZ9E114QbExFpwp56FG3fHlZOa07JAFRVVADZlgxuAq4EdjRInxFVBd1oZh2itP2A9L43NVHantJrMqTvxsymmFmVmVXV1tZmmfX43HRTuIYvv7zF31qk9AwfHtY1yNSj6M03w5dJwaDVaDIYmNlngfXuvrjBpunAIcCRQE/g2/Fnb1fuPtPdK929sk+fPoV+u90MHhzWBH/wQfjTn1r87UVKy4gR8N57mYvS2a5j0JAWuSmYbEoGRwOnm9mrwGzgBDP7pbu/GVUFbQV+QWgHAFgLpIf7AVHantIHZEhvla68MgSFSy6BbduKnRuRVizVoyhTVVGuYwxSunWDHj0UDAqgyWDg7tPdfYC7DwImAX9093Ojun6inj9nAKn/8XnA+VGvoqOATe7+JrAAGGtmPaKG47HAgmjbe2Z2VHSu84FHYv6csenUCW6+OZR+b7ml2LkRacX21L0019HH6dS9tCDyGWdwv5m9CLwI9Aa+H6XPB1YB1cAdwNcA3H0DcC3wfPS4Jkoj2ufO6JhXgEfzyFfBfe5zcNppcPXV8MYbxc6NSCvVs2dYuKaxkkHXrmEOo1wpGBRE21x2dvengKei5yc0so8DUxvZNguYlSG9ChiRS16K7eabw43PFVfA/fcXOzcirVRjC92kupU2Z0jRAQfA00/nnzfZhUYgN9OBB4b2gwcegKeeKnZuRFqpESNCneqOBh0Rc1nUpqGBA2HjxtA4LbFRMMjDtGkwaJAak0UaNXw4bNkCr766a3pzRh+nqHtpQSgY5KFz5zD2YNky+OlPi50bkVYo07QUW7eGueFz7VaaomBQEAoGeTr9dDjlFPiP/4Df/77YuRFpZYYNC3/T2w1S4w5UMmhVFAzyZAZ33QWHHBICww037HkKd5Gystde4cc7vWSQT7dS0CI3BaJgEIN+/ULnhrPPDr2LLrhAS2WKfKRhj6Lmjj5OqagIU18rGMRKwSAmnTuHaSquvhruuSeskrZ+fbFzJXk4LJpWfYmZVaUSzexSM1thZsvM7Pq09OnRFOwvm9nJaenjorRqM5vW0h+iVRgxAl56aee006lgMGBA48c0RWMNYqdgECMzuOoqmDMHXngBjjwS/v73YudK8nB8ND17JYCZHU+Ylfdwdx8O3BClDyOMzh9OmH79tmj9jwrgVsK07sOAL0T7lpfhw8NCIK+8El6//nqY3rpTp+afU8EgdgoGBTBhQqg22r4d/uVf4Le/LXaOJCb/Bvwomo8Ld0+V/cYDs919q7uvJoykHxM9qt19lbvXEeb2Gl+EfBdXwx5F+XQrTUktcrN9e37nkY8oGBTI6NHw/POhM8WZZ8IPf6iG5RL0uJktNrMp0euDgGPN7Fkz+7OZHRml5zpt+y6KPTV7wR16aCg2p9oNmrOoTUMHHKBFbmKmYFBA/frBn/8MX/xiWEf5vPPgww+LnSvJ0gp3H0Wo4plqZscRpm/pCRwFXAHMiWOJ1mJPzV5wnTvDkCE7Swb5jD5OUffS2CkYFFinTvDLX8KMGWEOo09/WpPblYht8FFV0MOEKp8a4DfR1O3PERZ76k3u07aXn1SPovfeC4+WDAarVsHvfpff+5UBBYMWYBZKBr/5Tbg5OvLIUIUkrdPmzZsh+m6YWRfCdOtLgd8Cx0fpBwHtgbcJ07ZPMrMOZjaYsI73c4TZeYea2WAza09oZJ7Xsp+mlRgxAv7xj52NyPlWE6WCyWuvNb5PXR384AchEJ1+OixalN97JpyCQQs680x45hlo3x6OPVaznbZW69atAzjEzP5G+FH/g7s/Rphxd4iZLSU0Bk+OSgnLgDnAcuAxYKq7b3f3euASwloeLwFzon3Lz/DhoWvpk0+G1/mWDJpa5Obpp+GII+C73w0LmPfsGRrupFEKBi3ssMNCqeCoo+Dcc+Hb31aHiNZmyJAhAMvd/XB3H+7uMwDcvc7dz3X3Ee4+yt3/mDrG3We4+4HufrC7P5qWPt/dD4q2zWj5T9NKpHoUPRr90+QbDCBz99J33oGvfhWOOw42bw5zxDz0EFx6Kcybl3ltBQEUDIqid29YuBD+7d/g+utDCXbTpmLnSqSADj44jBx++mlo0wb698//nOnBwB3uvTfMC3P33WF++WXLwipUEIJBly7wox/l/74JpWBQJO3awW23we23w+OPh5LCypXFzpVIgXToAEOHhrne+/cPcwvlKxUMXn4ZTjoJJk+Gj30M/vpXuO668OOf0qsXXHwxzJ4Nq1fn/94tbdu2kP+HHirYWygYFNm//is88QTU1sKYMSEwiCRSak3kOKqIYOciNx//OCxeDD/7Gfzv/4bXmfz7v4fSyY9/HM/7t6SHHoKZM8MEaAWqW1YwaAU+9anQjrD//mE67Jtu0gA1SaBUu0FcwSD1o3/WWbBiRbhzbrOHn7T99gulh1mz4K234slDS3CHG28MJatU3fIpp4T2kRgpGLQSgwfD//0fjB8P3/xmuL4XLy52rkRilCoZ5NutNOXkk2HdurD2bN++2R1z5ZWhyuXGG+PJQ0t45hl47jm47LJQt3zXXWE0a2UlLFkS29soGLQiXbvC3LlhgNoTT4T/6+OPDx0iGi4hK1JyDjss/I0rGJjBPvvkdszHPgbnnBN+VN99N558FNp//VfoRjt5cnj9la+Ehvht28LkZw88EMvbKBi0Mm3ahAFqa9aEqs3qavjc58JN1cyZ8MEHxc6hSDMdfDDcd1+Yl6WYpk2Df/4Tbr21uPnIxurV8PDDMGXKrg3iY8aEqoMjj4QvfSm0h6SmCG8mBYNWau+94fLLw0j6++8P07tcfHGYn+s//zM0OIuUFLMwuKZ79+Lm4/DDQ5fTm28OYxFas//+73CHeMklu2/bd99QhfCNb4Rqr898Jq8fBgWDVq5duzDRXVUV/PGP4Ybg6qtDSfvii+HVV4udQ5ESNH06vP023HlnsXPSuPfeC/mbMKHxhYDatQs9Tu69N0y3MXp0sxsbFQxKhNnO9oPly0NJ+557woj7BQuKnTuREnP00WGU8g03hDmMWqNZs+D990OPkqacd17oVmsWPlszJuZTMChBhx4a2g9eeimUEE49NVzT6o4qkoPp08MCOb/8ZbFzsrvt2+GWW8IP+5FHNr0/wKhRoVQwfnyoCsuRgkEJS3VHPessuOKKUB27ZUuxcyVSIk4+ORStr7uuMBOEbdnS/EbdRx4JjcfZlArS9e4dFmNvR
o8tBYMS16VL+L+fMQN+9Ss45hit9yGSFbNQOvjHP0KPnTht3Rr6hn/yk81rpL7xRhg0CM44I9587YGCQQKk1kv43e/CdPGVlfCXvxQ7VyIl4POfh4MOCusexFnP+tOfhnrcxYvh/PNzGyhUVQX/8z/w9a+H6TNaiIJBgpx2Whio2LMnnHhiGFejdgSRPaioCHP9vPBCfBOD1dbCtdeGKSN+8pOwqtX/+3/ZH3/jjWG9hgsvjCc/WVIwSJiDD4Znnw3VoVOnhrEqW7cWO1cirdi554aumz/4QTznu+qqMKjtJz8JU0hcdFE49333NX3s2rUwZ05Yk2GvveLJT5YUDBJo771D+9N3vxu6KZ9wQhjJLCIZtG8fRnj+5S+hR0Y+li6Fn/88TCh36KGhDvenPw2Ln3/1q6H755789KehSunrX88vH82gYJBQFRXw/e/Dr38d5rIaOjSMU7jvPvU4EtnNV78a1jy49trm1626w7e+Fe7or756Z3r79mHSsYEDw9q3jY0U3bw5BJIzzwyNxy1MwSDhzj47dJaYMSPMd3T++WGCxylTwoBFtSmIELrlTZ8Ojz0Wqnea49FHQ7vDVVeFwJKuV6/Qw6OuLkw29t57ux9/771h8rxcu5PGxd1L8jF69GiX3OzY4f7nP7tPnuzeubM7uB96qPuPf+z+1lvFzl3rAlS5ruvysn27+4QJ7mbu8+bldmxdnfvBB7sfdJD71q2N77dwoXtFhftpp7nX1+/63gcd5H7kkeGLWiB7uq5VMigjZmEE/t13w5tvwh13hDnDrrgirPsxfjz84hehBCFSdtq0CV+OUaPChGAvvpj9sT/7WVh+84YbQrVQY046KUw+94c/hC9eyqOPhiL8N78ZvqhFYF6i9QSVlZVeVVVV7GwkwooVIQjcd18IEhB6JZ10UpgI8dOfDo3S5cTMFrt7ZUu/r67rVmDt2jAFRPv2oa92U2smbNgQ1kkYPTpUE2XzY37ppaGxeObM0NvopJNCMFm1Kkw+VyB7uq5VMhAOOSSMyF+7Fv72t1BlOmRICBBnnBHGLXzyk/C974UOF611Xi+RWOy3H8ybF1ZR+/znm+6bfc01sGlTWIQm27v6G2+EsWPha18LcxA9+WSYprqAgaApKhlIo+rqwop7TzwRHs89F3q9deoUOkb079/4o1+/sF+pUslAmDMHJk4MvS7uvjvzD/3LL4e1nS+8MFQV5WLjxnCXtWJFWLCkpiasaFZAe7qu2xb0naWktW8Pn/pUeFx7bbh2n3oqlA7WrIE33gjdpt94I3NpoXdvOOqo0E5x3HGhKraINz4iuTnnnDClxNVXh6UGr7xy930uvzz8kF9zTe7n7949zEl/zDEh4BQ4EDRFwUCy1r17qDZqOHeWe6g2feONXR+vvBKCxe9/H/br3DncCKWCwyc+UdqlBykD3/teCAjTpoX61NNP37lt4cJwcV9/fe5rMacceCC89lqruEtSNZEU3FtvhfW7n346lCr+/vcQQNq1Cyu3VVaGbt7t24e09u13faTSunYNswYMHBgCUyE7XaiaSD7ywQeheLx8eRih/PGPh6mpjzgijOBcvhw6dCh2LrOiaiIpqr59w8p9EyaE1+++G0oMqeBw552hjS6Xqd+7dIH998/8GDQorPXQCm62JAk6dYLf/jbcuXzuc6Hx7Le/DVNPPPRQyQSCpigYSIvr0QM++9nwSOcO27aF9oe6ul2f19WFDhs1NWG9hjVrdj6WLg2lj/RCbkVFCApDh4Zef0OH7nwMGgRtdeVLLvr3DxN+HXtsGJCzalUoLZx5ZrFzFht9JaTVMNtZNZSrurrQNXbNmrBA1MqV4VFdHUoh77+/c9+2bUPJ4bOfDb0BG3GYmb0IbAfq3b3SzK4GLgJqo32+4+7zQ95tOnBhtP/X3X1BlD4OuBmoAO509x/l/umkVRg9OkwZMWFCuFhz6UpaAhQMJBHatw8/8IMHh8bpdO6wfv3OAJEKEp07N3na49397QZpN7r7DekJZjYMmAQMB/oDT5jZQdHmW4HPADXA82Y2z92XN+czSitw9tlw113h7mLUqGLnJlYKBpJ4ZrDvvuFxzDEFeYvxwGx33wqsNrNqYEy0rdrdV4V82OxoXwWDUvaVrxQ7BwWhEcgijXvczBab2ZS0tEvM7O9mNsvMUh3D9wPSZ3SqidIaS9+FmU0xsyozq6qtrW24WaRFKBiIZLbC3UcBpwBTzew44HbgQGAk8CbQzLmOd+XuM9290t0r+/TpE8cpRXKmYCCS2TYAd18PPAyMcfd17r7d3XcAd7CzKmgtsH/asQOitMbSRVodBQORBjZv3gzRd8PMugBjgaVm1i9ttzOBpdHzecAkM+tgZoOBocBzwPPAUDMbbGbtCY3M81rmU4jkRg3IIg2sW7cO4BAz+xvhO/KAuz9mZveZ2UjAgVeBiwHcfZmZzSE0DNcDU919O4CZXQIsIHQtneXuy1r684hkQ8FApIEhQ4YALG84bN/dz2vsGHefAczIkD4fmB93HkXilnU1kZlVmNkLZvb76PVgM3vWzKrN7MGoGExUVH4wSn/WzAalnWN6lP6ymZ2clj4uSqs2s2nxfTwREclGLm0G3wBeSnt9HWEAzseAdwmjL4n+vhul3xjt13BgzjjgtijAVBAG5pwCDAO+EO0rIiItJKtgYGYDgNOAO6PXBpwAzI12uQdITWw8PnpNtP3EaP+PBua4+2ogNTBnDNHAHHevA1IDc0REpIVkWzK4CbgS2BG97gVsdPfUPJPpg2k+GmgTbd8U7Z/XwBzQ4BwRkUJpMhiY2WeB9e6+uAXys0canCMiUhjZ9CY6GjjdzE4FOgJ7EWZh7G5mbaO7//TBNKmBNjVm1hbYG3iHPQ/A0cAcEZEiarJk4O7T3X2Auw8iNAD/0d2/BPwJODvabTLwSPR8XvSaaPsfPSynpoE5IiKtVE7LXprZp4HL3f2zZjaE0NjbE3gBONfdt5pZR+A+4AhgAzApbdbG7wJfIQzMuczdH43STyW0S6QG5uzWXztDXmqB1xrZ3BtoOPVwqdFnKK4D3L3F6yLL4LqGZHyOUv0MQ4Fn3H1cww0luwbynphZVTHWr42TPoM0lJR/zyR8jiR8hoY0N5GIiCgYiIhIcoPBzGJnIAb6DNJQUv49k/A5kvAZdpHINgMREclNUksGIiKSAwUDERFJVjBIylTYZvaqmb1oZkvMrKrY+clGtED8ejNbmpbW08wWmtnK6G+PPZ1DGpeEa7sUr2son2s7McEggVNhH+/uI0uoL/PdhKnJ000DnnT3ocCT0WvJUcKu7VK7rqFMru3EBAM0FXZRuftfCCPO06VPZ54+zbnkRtd2EZXLtZ2kYJD1VNglwIHHzWyxmU0pdmbysK+7vxk9fwvYt5iZKWFJubaTcl1DAq9trYHcOh3j7mvNbB9goZmtiO5OSpa7u5mpH3N5S9x1Dcm5tpNUMtjTFNklxd3XRn/XAw8TqglK0Toz6wcQ/V1f5PyUqkRc2wm6riGB13aSgkEipsI2sy5m1i31HBgLLN3zUa1W+nTm
6dOcS25K/tpO2HUNCby2E1NN5O71ZnYJsICdU2EvK3K2mmNf4OGwbDRtgQfc/bHiZqlpZvYr4NNAbzOrAa4CfgTMMbMLCdMyn1O8HJauhFzbJXldQ/lc25qOQkREElVNJCIizaRgICIiCgYiIqJgICIiKBiIiAgKBiIigoKBiIgA/x9pNrRrz2SHvwAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 432x432 with 2 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cDa4nuQvjGSa"
      },
      "source": [
        "# VII. Test model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AFCcFv4mnmRi",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bf220347-2681-4466-dc70-060c0291b5cc"
      },
      "source": [
        "def test_scikit_ap(cat_preds, cat_labels):\n",
        "  ap = np.zeros(26, dtype=np.float32)\n",
        "  for i in range(26):\n",
        "    ap[i] = average_precision_score(cat_labels[i, :], cat_preds[i, :])\n",
        "  print ('ap', ap, ap.shape, ap.mean())\n",
        "  return ap.mean()\n",
        "\n",
        "\n",
        "def test_emotic_vad(cont_preds, cont_labels):\n",
        "  vad = np.zeros(3, dtype=np.float32)\n",
        "  for i in range(3):\n",
        "    vad[i] = np.mean(np.abs(cont_preds[i, :] - cont_labels[i, :]))\n",
        "  print ('vad', vad, vad.shape, vad.mean())\n",
        "  return vad.mean()\n",
        "\n",
        "\n",
        "def get_thresholds(cat_preds, cat_labels):\n",
        "  thresholds = np.zeros(26, dtype=np.float32)\n",
        "  for i in range(26):\n",
        "    p, r, t = precision_recall_curve(cat_labels[i, :], cat_preds[i, :])\n",
        "    for k in range(len(p)):\n",
        "      if p[k] == r[k]:\n",
        "        thresholds[i] = t[k]\n",
        "        break\n",
        "  np.save('./thresholds.npy', thresholds)\n",
        "  return thresholds\n",
        "\n",
        "print ('completed cell')"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "completed cell\n"
          ],
          "name": "stdout"
        }
      ]
    },
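    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick, optional sanity check of the metric helpers above (not part of the original evaluation flow): on random scores and random binary labels with roughly 30% positives, the mean average precision should land near 0.3, and `get_thresholds` should return one threshold per category. The shapes are chosen only to match the hard-coded 26 categories and 3 VAD dimensions; note that `get_thresholds` writes `./thresholds.npy` as a side effect, and the real thresholds are recomputed from actual predictions further below."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Illustrative sanity check only -- assumes the metric helpers above have run\n",
        "# and that numpy is available as np, as in the rest of the notebook.\n",
        "rng = np.random.RandomState(0)\n",
        "n = 200\n",
        "fake_cat_preds = rng.rand(26, n)                          # random scores in [0, 1]\n",
        "fake_cat_labels = (rng.rand(26, n) < 0.3).astype(float)   # ~30% positives per category\n",
        "fake_cont_preds = rng.rand(3, n) * 10                     # VAD-style values on a 0-10 scale\n",
        "fake_cont_labels = rng.rand(3, n) * 10\n",
        "\n",
        "_ = test_scikit_ap(fake_cat_preds, fake_cat_labels)       # mean AP should be close to 0.3\n",
        "_ = test_emotic_vad(fake_cont_preds, fake_cont_labels)    # per-dimension mean absolute error\n",
        "fake_thresholds = get_thresholds(fake_cat_preds, fake_cat_labels)  # also overwrites ./thresholds.npy\n",
        "print (fake_thresholds.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },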
    {
      "cell_type": "code",
      "metadata": {
        "id": "KOeZRVdbUPNx",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e20ad71b-9d42-47f5-cda8-0bd08abb27c4"
      },
      "source": [
        "def test_data(models, device, data_loader, num_images):\n",
        "    model_context, model_body, emotic_model = models\n",
        "    cat_preds = np.zeros((num_images, 26))\n",
        "    cat_labels = np.zeros((num_images, 26))\n",
        "    cont_preds = np.zeros((num_images, 3))\n",
        "    cont_labels = np.zeros((num_images, 3))\n",
        "\n",
        "    with torch.no_grad():\n",
        "        model_context.to(device)\n",
        "        model_body.to(device)\n",
        "        emotic_model.to(device)\n",
        "        model_context.eval()\n",
        "        model_body.eval()\n",
        "        emotic_model.eval()\n",
        "        indx = 0\n",
        "        print ('starting testing')\n",
        "        for images_context, images_body, labels_cat, labels_cont in iter(data_loader):\n",
        "            images_context = images_context.to(device)\n",
        "            images_body = images_body.to(device)\n",
        "\n",
        "            pred_context = model_context(images_context)\n",
        "            pred_body = model_body(images_body)\n",
        "            pred_cat, pred_cont = emotic_model(pred_context, pred_body)\n",
        "\n",
        "            cat_preds[ indx : (indx + pred_cat.shape[0]), :] = pred_cat.to(\"cpu\").data.numpy()\n",
        "            cat_labels[ indx : (indx + labels_cat.shape[0]), :] = labels_cat.to(\"cpu\").data.numpy()\n",
        "            cont_preds[ indx : (indx + pred_cont.shape[0]), :] = pred_cont.to(\"cpu\").data.numpy() * 10\n",
        "            cont_labels[ indx : (indx + labels_cont.shape[0]), :] = labels_cont.to(\"cpu\").data.numpy() * 10 \n",
        "            indx = indx + pred_cat.shape[0]\n",
        "\n",
        "    cat_preds = cat_preds.transpose()\n",
        "    cat_labels = cat_labels.transpose()\n",
        "    cont_preds = cont_preds.transpose()\n",
        "    cont_labels = cont_labels.transpose()\n",
        "    scipy.io.savemat('./cat_preds.mat',mdict={'cat_preds':cat_preds})\n",
        "    scipy.io.savemat('./cat_labels.mat',mdict={'cat_labels':cat_labels})\n",
        "    scipy.io.savemat('./cont_preds.mat',mdict={'cont_preds':cont_preds})\n",
        "    scipy.io.savemat('./cont_labels.mat',mdict={'cont_labels':cont_labels})\n",
        "    print ('completed testing')\n",
        "    ap_mean = test_scikit_ap(cat_preds, cat_labels)\n",
        "    vad_mean = test_emotic_vad(cont_preds, cont_labels)\n",
        "    print (ap_mean, vad_mean)\n",
        "    return ap_mean, vad_mean \n",
        "\n",
        "print ('completed cell')"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "completed cell\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qIUQLrXBZ2RR",
        "outputId": "c958d8ba-6e32-438f-c5c5-816d9b9ed829"
      },
      "source": [
        "model_context = torch.load('./models/model_context1.pth')\n",
        "model_body = torch.load('./models/model_body1.pth')\n",
        "emotic_model = torch.load('./models/model_emotic1.pth')\n",
        "\n",
        "print ('completed cell')"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "completed cell\n"
          ],
          "name": "stdout"
        }
      ]
    },
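    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cell above loads whole pickled model objects with `torch.load`. A minimal, optional sketch (assuming the same checkpoint paths): on a CPU-only runtime the same files can be loaded by passing `map_location`, with the models put into eval mode just as `test_data` does for the GPU copies."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Minimal sketch (assumption: same checkpoint paths as the cell above) for\n",
        "# loading the serialised models on a CPU-only runtime via map_location.\n",
        "cpu = torch.device('cpu')\n",
        "model_context_cpu = torch.load('./models/model_context1.pth', map_location=cpu)\n",
        "model_body_cpu = torch.load('./models/model_body1.pth', map_location=cpu)\n",
        "emotic_model_cpu = torch.load('./models/model_emotic1.pth', map_location=cpu)\n",
        "\n",
        "for m in (model_context_cpu, model_body_cpu, emotic_model_cpu):\n",
        "    m.eval()  # inference mode, as test_data sets for the GPU copies\n",
        "print ('loaded CPU copies of the three models')"
      ],
      "execution_count": null,
      "outputs": []
    },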
    {
      "cell_type": "code",
      "metadata": {
        "id": "oB69Xo-kLldG",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "b6be064a-25b2-43d3-e7e7-51a7fc8a9304"
      },
      "source": [
        "val_ap, val_vad = test_data([model_context, model_body, emotic_model], device, val_loader, val_dataset.__len__())\n",
        "test_ap, test_vad = test_data([model_context, model_body, emotic_model], device, test_loader, test_dataset.__len__())\n",
        "\n",
        "print ('validation Mean average precision=%.4f Mean VAD MAE=%.4f' %(val_ap, val_vad))\n",
        "print ('testing Mean average precision=%.4f Mean VAD MAE=%.4f' %(test_ap, test_vad))"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "starting testing\n",
            "completed testing\n",
            "ap [0.3983917  0.18015468 0.22337271 0.95204633 0.17163357 0.7866947\n",
            " 0.23361506 0.37178904 0.19096893 0.20868655 0.06009851 0.98069084\n",
            " 0.26645675 0.7951143  0.13405906 0.08186857 0.8081806  0.16670538\n",
            " 0.29040682 0.49211633 0.20419936 0.08260126 0.18704712 0.14419095\n",
            " 0.3501988  0.11717057] (26,) 0.34147915\n",
            "vad [0.70697206 0.8584789  0.86687875] (3,) 0.81077653\n",
            "0.34147915 0.81077653\n",
            "starting testing\n",
            "completed testing\n",
            "ap [0.29003292 0.08763415 0.14132965 0.56043494 0.07053518 0.75399864\n",
            " 0.11882206 0.2385993  0.16040386 0.173684   0.01993784 0.86009395\n",
            " 0.15641297 0.69662005 0.09915597 0.06025878 0.66563565 0.06506737\n",
            " 0.21911173 0.4214436  0.17897978 0.05904196 0.1752331  0.08228464\n",
            " 0.13343503 0.0820521 ] (26,) 0.2527015\n",
            "vad [0.8996919 1.0314642 0.943558 ] (3,) 0.95823807\n",
            "0.2527015 0.95823807\n",
            "validation Mean average precision=0.3415 Mean VAD MAE=0.8108\n",
            "testing Mean average precision=0.2527 Mean VAD MAE=0.9582\n"
          ],
          "name": "stdout"
        }
      ]
    },
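    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For readability, the per-category APs printed above can be paired with the 26 EMOTIC category names. This is an added, illustrative step: the names and their ordering are taken from the author's Octave script in section VIII, and it is assumed that the rows of the saved `cat_preds.mat` / `cat_labels.mat` (which at this point hold the results of the most recent `test_data` call, i.e. the test split) follow the same order."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Illustrative: label each per-category AP with its EMOTIC category name.\n",
        "# Assumes the row order matches the (alphabetical) list used by the author's\n",
        "# Octave script in section VIII, and reuses scipy / sklearn imported earlier.\n",
        "emotic_categories = ['Affection', 'Anger', 'Annoyance', 'Anticipation', 'Aversion',\n",
        "                     'Confidence', 'Disapproval', 'Disconnection', 'Disquietment',\n",
        "                     'Doubt/Confusion', 'Embarrassment', 'Engagement', 'Esteem',\n",
        "                     'Excitement', 'Fatigue', 'Fear', 'Happiness', 'Pain', 'Peace',\n",
        "                     'Pleasure', 'Sadness', 'Sensitivity', 'Suffering', 'Surprise',\n",
        "                     'Sympathy', 'Yearning']\n",
        "\n",
        "test_cat_preds = scipy.io.loadmat('./cat_preds.mat')['cat_preds']\n",
        "test_cat_labels = scipy.io.loadmat('./cat_labels.mat')['cat_labels']\n",
        "for name, scores, labels in zip(emotic_categories, test_cat_preds, test_cat_labels):\n",
        "    print ('%-18s AP = %.4f' % (name, average_precision_score(labels, scores)))"
      ],
      "execution_count": null,
      "outputs": []
    },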
    {
      "cell_type": "code",
      "metadata": {
        "id": "T-fc5LNp4len",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "8de112fa-a4bd-43c0-ff44-895b1ae32fe1"
      },
      "source": [
        "cat_labels = scipy.io.loadmat('./cat_labels.mat')\n",
        "cat_preds = scipy.io.loadmat('./cat_preds.mat')\n",
        "cat_preds = cat_preds['cat_preds']\n",
        "cat_labels = cat_labels['cat_labels']\n",
        "print (cat_preds.shape, cat_labels.shape)\n",
        "\n",
        "#thesholds calculation for inference \n",
        "thresholds = get_thresholds(cat_preds, cat_labels)\n",
        "print (thresholds, thresholds.shape)\n",
        "\n",
        "print ('completed cell')"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "(26, 7203) (26, 7203)\n",
            "[0.11334415 0.32935348 0.17811956 0.1820814  0.24816841 0.13238849\n",
            " 0.23765785 0.10895684 0.07811652 0.07971309 0.14207679 0.47783324\n",
            " 0.08085962 0.14741261 0.12622227 0.12906708 0.22126663 0.2721243\n",
            " 0.10970519 0.10124312 0.18777776 0.14807722 0.2636854  0.09791826\n",
            " 0.0983988  0.0875175 ] (26,)\n",
            "completed cell\n"
          ],
          "name": "stdout"
        }
      ]
    },
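    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of how the saved per-category thresholds would be used at inference time: a category is predicted whenever its score clears the corresponding threshold. It reuses `cat_preds` and `thresholds` from the cell above, plus the `emotic_categories` list from the illustrative cell earlier (an added assumption about category ordering), and simply prints the predicted categories for the first test image."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Illustrative multi-label decision using the thresholds computed above.\n",
        "# Assumes cat_preds (26 x num_images) and thresholds (26,) from the previous cell,\n",
        "# plus the emotic_categories list defined in the earlier illustrative cell.\n",
        "predicted = cat_preds > thresholds[:, np.newaxis]   # boolean matrix, (26, num_images)\n",
        "\n",
        "first_image_cats = [name for name, flag in zip(emotic_categories, predicted[:, 0]) if flag]\n",
        "print ('predicted categories for the first test image:', first_image_cats)"
      ],
      "execution_count": null,
      "outputs": []
    },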
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "owTpkHmOjLvr"
      },
      "source": [
        "# VIII. Average Precision computation using <a href=\"https://1drv.ms/u/s!AkYHbdGNmIVCgbYZB_dY3wuWJou_5A?e=jcsZUj\">author's script</a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "30PEDPHxrkXA",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        },
        "outputId": "8d2ed78c-fadb-40fc-8f11-be409beb8ea0"
      },
      "source": [
        "!apt install octave"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Reading package lists... Done\n",
            "Building dependency tree       \n",
            "Reading state information... Done\n",
            "octave is already the newest version (4.2.2-1ubuntu1).\n",
            "0 upgraded, 0 newly installed, 0 to remove and 31 not upgraded.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6fWR4CTMr7Hf",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "b7539f27-3a07-4184-f67f-d4b3d84350f7"
      },
      "source": [
        "%%writefile eval.m\n",
        "\n",
        "gt = load('./cat_labels.mat')\n",
        "gt = gt.cat_labels\n",
        "\n",
        "pred = load('./cat_preds.mat')\n",
        "pred = pred.cat_preds\n",
        "\n",
        "categories{1} = 'Affection';\n",
        "categories{2} = 'Anger';\n",
        "categories{3} = 'Annoyance';\n",
        "categories{4} = 'Anticipation';\n",
        "categories{5} = 'Aversion';\n",
        "categories{6} = 'Confidence';\n",
        "categories{7} = 'Disapproval';\n",
        "categories{8} = 'Disconnection';\n",
        "categories{9} = 'Disquietment';\n",
        "categories{10} = 'Doubt/Confusion';\n",
        "categories{11} = 'Embarrassment';\n",
        "categories{12} = 'Engagement';\n",
        "categories{13} = 'Esteem';\n",
        "categories{14} = 'Excitement';\n",
        "categories{15} = 'Fatigue';\n",
        "categories{16} = 'Fear';\n",
        "categories{17} = 'Happiness';\n",
        "categories{18} = 'Pain';\n",
        "categories{19} = 'Peace';\n",
        "categories{20} = 'Pleasure';\n",
        "categories{21} = 'Sadness';\n",
        "categories{22} = 'Sensitivity';\n",
        "categories{23} = 'Suffering';\n",
        "categories{24} = 'Surprise';\n",
        "categories{25} = 'Sympathy';\n",
        "categories{26} = 'Yearning';\n",
        "\n",
        "\n",
        "for c = 1:length(categories)\n",
        "  confidence = pred(c,:)'; \n",
        "  testClass = gt(c,:)';\n",
        "  confidence = double(confidence);\n",
        "\n",
        "  S = rand('state');\n",
        "  rand('state',0);\n",
        "  confidence = confidence + rand(size(confidence))*10^(-10);\n",
        "  rand('state',S)\n",
        "\n",
        "  [S,j] = sort(-confidence);\n",
        "  C = testClass(j);\n",
        "  n = length(C);\n",
        "    \n",
        "  REL = sum(C);\n",
        "  if n>0\n",
        "    RETREL = cumsum(C);\n",
        "    RET    = (1:n)';\n",
        "  else\n",
        "    RETREL = 0;\n",
        "    RET    = 1;\n",
        "  end\n",
        "\n",
        "  precision = 100*RETREL ./ RET;\n",
        "  recall    = 100*RETREL  / REL;\n",
        "  th = -S;\n",
        "\n",
        "  % compute AP\n",
        "  mrec=[0 ; recall ; 100];\n",
        "  mpre=[0 ; precision ; 0];\n",
        "  for i=numel(mpre)-1:-1:1\n",
        "    mpre(i)=max(mpre(i),mpre(i+1));\n",
        "  end\n",
        "  i=find(mrec(2:end)~=mrec(1:end-1))+1;\n",
        "  averagePrecision=sum((mrec(i)-mrec(i-1)).*mpre(i))/100;\n",
        "  ap_list(c)  = averagePrecision\n",
        "end\n",
        "\n",
        "display('#######################################')\n",
        "\n",
        "display('Average precision of predictions');\n",
        "for c = 1:length(categories)\n",
        "    sp = '............................';\n",
        "    cat = strcat(categories{c}, sp);\n",
        "    cat = cat(1:18);\n",
        "    display(cat);\n",
        "    display(ap_list(c));\n",
        "end"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Overwriting eval.m\n"
          ],
          "name": "stdout"
        }
      ]
    },
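    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For readers who prefer to stay in Python, below is a rough numpy sketch of the same interpolated-AP computation that `eval.m` performs: add tiny tie-breaking noise, sort by descending score, make precision non-increasing from the right, then sum precision over the recall steps. It is intended only as a cross-check against the Octave output, not as a replacement for the author's script; `voc_style_ap` is a name introduced here, and `cat_preds` / `cat_labels` are the arrays loaded from the `.mat` files above."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Rough numpy equivalent of the interpolated AP computed by eval.m (cross-check only).\n",
        "def voc_style_ap(scores, labels, rng):\n",
        "    scores = scores + rng.rand(scores.shape[0]) * 1e-10   # tie-breaking noise, as in eval.m\n",
        "    order = np.argsort(-scores)                           # sort by descending score\n",
        "    c = labels[order].astype(float)\n",
        "    rel = c.sum()\n",
        "    if rel == 0:\n",
        "        return 0.0\n",
        "    retrel = np.cumsum(c)\n",
        "    ret = np.arange(1, len(c) + 1)\n",
        "    precision = 100.0 * retrel / ret\n",
        "    recall = 100.0 * retrel / rel\n",
        "    mrec = np.concatenate(([0.0], recall, [100.0]))\n",
        "    mpre = np.concatenate(([0.0], precision, [0.0]))\n",
        "    for i in range(len(mpre) - 2, -1, -1):                # make precision non-increasing\n",
        "        mpre[i] = max(mpre[i], mpre[i + 1])\n",
        "    idx = np.where(mrec[1:] != mrec[:-1])[0] + 1\n",
        "    return np.sum((mrec[idx] - mrec[idx - 1]) * mpre[idx]) / 100.0\n",
        "\n",
        "rng = np.random.RandomState(0)\n",
        "ap_py = np.array([voc_style_ap(cat_preds[c], cat_labels[c], rng) for c in range(26)])\n",
        "print ('numpy per-category AP:', ap_py)\n",
        "print ('numpy mean AP: %.4f' % ap_py.mean())"
      ],
      "execution_count": null,
      "outputs": []
    },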
    {
      "cell_type": "code",
      "metadata": {
        "id": "fA1Oc48zvI_l"
      },
      "source": [
        "!octave -W eval.m"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}